/*
 * Copyright (C) 2017 Google Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "interpreter.h"
#include "memory_manager.h"

#include "core/cc/log.h"

#include <unistd.h>
#include <utility>
#include <vector>

namespace gapir {

namespace {

// Adds two values of the same arithmetic type.
template<typename T> inline T sum2(T a, T b) { return a + b; }

// Signed 32/64-bit addition overflows with undefined behaviour in C++. Perform
// the addition on the unsigned counterpart (which wraps modulo 2^N) and convert
// back, yielding two's-complement wrap-around instead of UB. (The conversion
// back is modular on all mainstream compilers and guaranteed so since C++20.)
// Being non-templates, these are preferred over sum2<T> for exact matches.
inline int32_t sum2(int32_t a, int32_t b) {
  return static_cast<int32_t>(static_cast<uint32_t>(a) + static_cast<uint32_t>(b));
}
inline int64_t sum2(int64_t a, int64_t b) {
  return static_cast<int64_t>(static_cast<uint64_t>(a) + static_cast<uint64_t>(b));
}

// Adds two pointers by treating both as integer addresses.
template<typename T> inline T* sum2(T* a, T* b) {
  return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(a) + reinterpret_cast<uintptr_t>(b));
}

// Pops `count` values of type T from `stack`, pushes their sum and returns
// whether the stack is still valid afterwards.
template<typename T>
inline bool sum(Stack& stack, uint32_t count) {
  T v = 0;
  for (uint32_t i = 0; i < count; i++) {
    v = sum2(v, stack.pop<T>());
  }
  stack.push(v);
  return stack.isValid();
}

} // anonymous namespace

// Creates an interpreter that resolves addresses through `memoryManager`, owns
// a stack sized by `stackDepth`, and keeps `callback`. call() and registerApi()
// invoke the callback to set up an API's renderer functions. The
// PRINT_STACK_FUNCTION_ID built-in is registered up front.
Interpreter::Interpreter(const MemoryManager* memoryManager, uint32_t stackDepth,
                         ApiRequestCallback callback) :
        mMemoryManager(memoryManager),
        apiRequestCallback(std::move(callback)),
        // Pass the constructor argument rather than mMemoryManager so that the
        // stack's initialization does not depend on the order in which the
        // members happen to be declared in the header.
        mStack(stackDepth, memoryManager),
        mLabel(0) {
    // Built-ins are consulted by call() before any API renderer functions.
    registerBuiltin(PRINT_STACK_FUNCTION_ID, [](Stack* stack, bool) {
        stack->printStack();
        return true;
    });
}

// Adds `func` to the built-in function table under `id`. call() looks up
// built-ins first, before falling back to the API renderer function tables.
void Interpreter::registerBuiltin(FunctionTable::Id id, FunctionTable::Function func) {
    auto& builtins = mBuiltins;
    builtins.insert(id, func);
}

// Installs `functionTable` as the renderer function table for API index `api`,
// replacing any table previously set for that index. Only the pointer is
// stored — NOTE(review): ownership appears to stay with the caller; confirm the
// table outlives this interpreter.
void Interpreter::setRendererFunctions(uint8_t api, FunctionTable* functionTable) {
    auto& slot = mRendererFunctions[api];
    slot = functionTable;
}

// Interprets the opcodes in `instructions` (pointer to the first opcode and the
// number of opcodes) in order. On the first opcode that fails, logs its index,
// its value and the last label reached, then returns false. Returns true when
// every opcode succeeds.
bool Interpreter::run(const std::pair<const uint32_t*, uint32_t>& instructions) {
    const uint32_t* const opcodes = instructions.first;
    const uint32_t opcodeCount = instructions.second;
    for (uint32_t index = 0; index != opcodeCount; ++index) {
        if (interpret(opcodes[index])) {
            continue;
        }
        GAPID_WARNING(
                "Interpreter stopped because of an interpretation error at opcode %u (%u)\n"
                "Last reached label: %d\n",
                index, opcodes[index], mLabel);
        return false;
    }
    return true;
}

// Decodes the operand type stored in the TYPE_MASK bits of `opcode`.
BaseType Interpreter::extractType(uint32_t opcode) const {
    const auto typeBits = (opcode & TYPE_MASK) >> TYPE_BIT_SHIFT;
    return static_cast<BaseType>(typeBits);
}

// Returns the 20-bit data field of `opcode`, i.e. the bits selected by
// DATA_MASK20.
uint32_t Interpreter::extract20bitData(uint32_t opcode) const {
    const uint32_t data = opcode & DATA_MASK20;
    return data;
}

// Returns the 26-bit data field of `opcode`, i.e. the bits selected by
// DATA_MASK26.
uint32_t Interpreter::extract26bitData(uint32_t opcode) const {
    const uint32_t data = opcode & DATA_MASK26;
    return data;
}

// Invokes the API request callback for API index `api` and returns its result.
// call() uses the same callback lazily when it meets an API with no renderer
// function table registered.
bool Interpreter::registerApi(uint8_t api) {
    const bool registered = apiRequestCallback(this, api);
    return registered;
}

// Executes a CALL instruction.
//
// The function id (FUNCTION_ID_MASK bits) and API index (API_INDEX_MASK bits)
// are decoded from `opcode`. Built-ins registered with registerBuiltin() take
// precedence. Otherwise the id is looked up in the renderer function table for
// the API. If no table is registered for that API yet, apiRequestCallback is
// given a chance to register one. Whether the PUSH_RETURN_MASK bit is set is
// forwarded to the function as its second argument.
//
// Returns false if no function can be found or the function reports failure.
bool Interpreter::call(uint32_t opcode) {
    auto id = opcode & FUNCTION_ID_MASK;
    auto api = (opcode & API_INDEX_MASK) >> API_BIT_SHIFT;
    auto func = mBuiltins.lookup(id);
    if (func == nullptr) {
        bool haveTable = mRendererFunctions.count(api) > 0;
        if (!haveTable) {
            if (apiRequestCallback(this, api)) {
                // A successful callback does not guarantee that a table was
                // registered for this API. Re-check rather than letting
                // operator[] default-insert (and then dereference) a null entry.
                haveTable = mRendererFunctions.count(api) > 0;
            } else {
                GAPID_WARNING("Error setting up renderer functions for api: %u", api);
            }
        }
        if (haveTable) {
            const auto& table = mRendererFunctions[api];
            if (table != nullptr) {
                func = table->lookup(id);
            }
        }
    }
    if (func == nullptr) {
        GAPID_WARNING("Invalid function id(%u), in api(%u)", id, api);
        return false;
    }
    if (!(*func)(&mStack, (opcode & PUSH_RETURN_MASK) != 0)) {
        GAPID_WARNING("Error raised when calling function with id: %u", id);
        return false;
    }
    return true;
}

bool Interpreter::pushI(uint32_t opcode) {
    BaseType type = extractType(opcode);
    if (!isValid(type)) {
        GAPID_WARNING("Error: pushI basic type invalid %u", type);
        return false;
    }
    Stack::BaseValue data = extract20bitData(opcode);
    switch (type) {
        // Sign extension for signed types
        case BaseType::Int32:
        case BaseType::Int64:
            if (data & 0x80000) {
                data |= 0xfffffffffff00000ULL;
            }
            break;
        // Shifting the value into the exponent for floating point types
        case BaseType::Float:
            data <<= 23;
            break;
        case BaseType::Double:
            data <<= 52;
            break;
        default:
            break;
    }
    mStack.pushValue(type, data);
    return mStack.isValid();
}

bool Interpreter::loadC(uint32_t opcode) {
    BaseType type = extractType(opcode);
    if (!isValid(type)) {
      GAPID_WARNING("Error: loadC basic type invalid %u", type);
      return false;
    }
    const void* address = mMemoryManager->constantToAbsolute(extract20bitData(opcode));
    if (!isConstantAddressForType(address, type)) {
      GAPID_WARNING("Error: loadC not constant address %p", address);
      return false;
    }
    mStack.pushFrom(type, address);
    return mStack.isValid();
}

bool Interpreter::loadV(uint32_t opcode) {
    BaseType type = extractType(opcode);
    if (!isValid(type)) {
      GAPID_WARNING("Error: loadV basic type invalid %u", type);
      return false;
    }
    const void* address = mMemoryManager->volatileToAbsolute(extract20bitData(opcode));
    if (!isVolatileAddressForType(address, type)) {
      GAPID_WARNING("Error: loadV not volatile address %p", address);
      return false;
    }
    mStack.pushFrom(type, address);
    return mStack.isValid();
}

bool Interpreter::load(uint32_t opcode) {
    BaseType type = extractType(opcode);
    if (!isValid(type)) {
      GAPID_WARNING("Error: load basic type invalid %u", type);
      return false;
    }
    const void* address = mStack.pop<const void*>();
    if (!isReadAddress(address)) {
      GAPID_WARNING("Error: load not readable address %p", address);
      return false;
    }
    mStack.pushFrom(type, address);
    return mStack.isValid();
}

// Executes a POP instruction: hands the 26-bit operand to Stack::discard() and
// reports whether the stack is still valid.
bool Interpreter::pop(uint32_t opcode) {
    const uint32_t discardCount = extract26bitData(opcode);
    mStack.discard(discardCount);
    return mStack.isValid();
}

bool Interpreter::storeV(uint32_t opcode) {
    void* address = mMemoryManager->volatileToAbsolute(extract26bitData(opcode));
    if (!isVolatileAddressForType(address, mStack.getTopType())) {
      GAPID_WARNING("Error: storeV not volatile address %p", address);
      return false;
    }

    mStack.popTo(address);
    return mStack.isValid();
}

bool Interpreter::store() {
    void* address = mStack.pop<void*>();
    if (!isWriteAddress(address)) {
      GAPID_WARNING("Error: store not write address %p", address);
      return false;
    }
    mStack.popTo(address);
    return mStack.isValid();
}

bool Interpreter::resource(uint32_t opcode) {
    mStack.push<uint32_t>(extract26bitData(opcode));
    return this->call(Interpreter::RESOURCE_FUNCTION_ID);
}

bool Interpreter::post() {
    return this->call(Interpreter::POST_FUNCTION_ID);
}

bool Interpreter::copy(uint32_t opcode) {
    uint32_t count = extract26bitData(opcode);
    void* target = mStack.pop<void*>();
    const void* source = mStack.pop<const void*>();
    if (!isWriteAddress(target)) {
        GAPID_WARNING("Error: copy target is invalid %p %d", target, count);
        return false;
    }
    if (!isReadAddress(source)) {
        GAPID_WARNING("Error: copy source is invalid %p %d", target, count);
        return false;
    }
    if (source == nullptr) {
        GAPID_WARNING("Error: copy source address is null");
        return false;
    }
    if (target == nullptr) {
        GAPID_WARNING("Error: copy destination address is null");
        return false;
    }
    memcpy(target, source, count);
    return mStack.isValid();
}

// Executes a CLONE instruction: hands the 26-bit operand to Stack::clone() and
// reports whether the stack is still valid.
bool Interpreter::clone(uint32_t opcode) {
    const uint32_t operand = extract26bitData(opcode);
    mStack.clone(operand);
    return mStack.isValid();
}

bool Interpreter::strcpy(uint32_t opcode) {
    uint32_t count = extract26bitData(opcode);
    char* target = mStack.pop<char*>();
    const char* source = mStack.pop<const char*>();
    // Requires that the whole count is available, even if source is shorter.
    if (!isWriteAddress(target)) {
        GAPID_WARNING("Error: copy target is invalid %p %d", target, count);
        return false;
    }
    if (!isReadAddress(source)) {
        GAPID_WARNING("Error: copy source is invalid %p %d", target, count);
        return false;
    }
    if (source == nullptr) {
        GAPID_WARNING("Error: strcpy source address is null");
        return false;
    }
    if (target == nullptr) {
        GAPID_WARNING("Error: strcpy destination address is null");
        return false;
    }
    uint32_t i;
    for (i = 0; i < count - 1; i++) {
        char c = source[i];
        if (c == 0) {
            break;
        }
        target[i] = c;
    }
    for (; i < count; i++) {
        target[i] = 0;
    }
    return mStack.isValid();
}

// Executes an EXTEND instruction: pops the top value and pushes it back, with
// the same type, after merging in the 26-bit operand.
//  * Float: the operand's low 23 bits are OR-ed into the value.
//  * Double: the top 12 bits are preserved. The value is shifted left by 26,
//    the operand appended, and the result truncated to the low 52 bits
//    beneath those preserved bits.
//  * Any other type: the value is shifted left by 26 and the operand appended.
bool Interpreter::extend(uint32_t opcode) {
    const uint32_t bits = extract26bitData(opcode);
    const auto topType = mStack.getTopType();
    auto result = mStack.popBaseValue();
    if (topType == BaseType::Float) {
        result |= (bits & 0x007fffffULL);
    } else if (topType == BaseType::Double) {
        const uint64_t highBits = result & 0xfff0000000000000ULL;
        const uint64_t lowBits = ((result << 26) | bits) & 0x000fffffffffffffULL;
        result = highBits | lowBits;
    } else {
        result = (result << 26) | bits;
    }
    mStack.pushValue(topType, result);
    return mStack.isValid();
}

// Executes an ADD instruction: pops `count` values (the 26-bit operand) of the
// type currently on top of the stack and pushes their sum. Returns false for
// types that cannot be added.
bool Interpreter::add(uint32_t opcode) {
    const uint32_t count = extract26bitData(opcode);
    if (count < 2) {
        // Adding zero or one value leaves the stack untouched.
        return mStack.isValid();
    }
    const auto type = mStack.getTopType();
    switch (type) {
        // Signed integers.
        case BaseType::Int8:   return sum<int8_t>(mStack, count);
        case BaseType::Int16:  return sum<int16_t>(mStack, count);
        case BaseType::Int32:  return sum<int32_t>(mStack, count);
        case BaseType::Int64:  return sum<int64_t>(mStack, count);
        // Unsigned integers.
        case BaseType::Uint8:  return sum<uint8_t>(mStack, count);
        case BaseType::Uint16: return sum<uint16_t>(mStack, count);
        case BaseType::Uint32: return sum<uint32_t>(mStack, count);
        case BaseType::Uint64: return sum<uint64_t>(mStack, count);
        // Floating point.
        case BaseType::Float:  return sum<float>(mStack, count);
        case BaseType::Double: return sum<double>(mStack, count);
        // Both pointer kinds are summed as integer addresses.
        case BaseType::AbsolutePointer:
        case BaseType::ConstantPointer:
            return sum<void*>(mStack, count);
        default:
            GAPID_WARNING("Cannot add values of type %s", baseTypeName(type));
            return false;
    }
}

// Executes a LABEL instruction: records the 26-bit operand as the last label
// reached; run() includes it in its failure message.
bool Interpreter::label(uint32_t opcode) {
    const uint32_t labelId = extract26bitData(opcode);
    mLabel = labelId;
    return mStack.isValid();
}

// Verbose tracing helpers used by interpret():
//  * DEBUG_OPCODE logs only the instruction name.
//  * DEBUG_OPCODE_26 also logs the 26-bit operand.
//  * DEBUG_OPCODE_TY_20 logs the 20-bit operand and the operand type name.
// `value` is parenthesized so expressions can be passed safely; `name` must be
// a string literal because it is concatenated into the format string.
#define DEBUG_OPCODE(name, value) GAPID_VERBOSE(name)
#define DEBUG_OPCODE_26(name, value) GAPID_VERBOSE(name "(%#010x)", (value) & DATA_MASK26)
#define DEBUG_OPCODE_TY_20(name, value) GAPID_VERBOSE(name "(%#010x, %s)", (value) & DATA_MASK20, baseTypeName(extractType(value)))

bool Interpreter::interpret(uint32_t opcode) {
    InstructionCode code = static_cast<InstructionCode>(opcode >> OPCODE_BIT_SHIFT);
    switch (code) {
        case InstructionCode::CALL:
            DEBUG_OPCODE_26("CALL", opcode);
            return this->call(opcode);
        case InstructionCode::PUSH_I:
            DEBUG_OPCODE_TY_20("PUSH_I", opcode);
            return this->pushI(opcode);
        case InstructionCode::LOAD_C:
            DEBUG_OPCODE_TY_20("LOAD_C", opcode);
            return this->loadC(opcode);
        case InstructionCode::LOAD_V:
            DEBUG_OPCODE_TY_20("LOAD_V", opcode);
            return this->loadV(opcode);
        case InstructionCode::LOAD:
            DEBUG_OPCODE_TY_20("LOAD", opcode);
            return this->load(opcode);
        case InstructionCode::POP:
            DEBUG_OPCODE_26("POP", opcode);
            return this->pop(opcode);
        case InstructionCode::STORE_V:
            DEBUG_OPCODE_26("STORE_V", opcode);
            return this->storeV(opcode);
        case InstructionCode::STORE:
            DEBUG_OPCODE("STORE", opcode);
            return this->store();
        case InstructionCode::RESOURCE:
            DEBUG_OPCODE_26("RESOURCE", opcode);
            return this->resource(opcode);
        case InstructionCode::POST:
            DEBUG_OPCODE("POST", opcode);
            return this->post();
        case InstructionCode::COPY:
            DEBUG_OPCODE_26("COPY", opcode);
            return this->copy(opcode);
        case InstructionCode::CLONE:
            DEBUG_OPCODE_26("CLONE", opcode);
            return this->clone(opcode);
        case InstructionCode::STRCPY:
            DEBUG_OPCODE_26("STRCPY", opcode);
            return this->strcpy(opcode);
        case InstructionCode::EXTEND:
            DEBUG_OPCODE_26("EXTEND", opcode);
            return this->extend(opcode);
        case InstructionCode::ADD:
            DEBUG_OPCODE_26("ADD", opcode);
            return this->add(opcode);
        case InstructionCode::LABEL:
            DEBUG_OPCODE_26("LABEL", opcode);
            return this->label(opcode);
        default:
            GAPID_WARNING("Unknown opcode! %#010x", opcode);
            return false;
    }
}

// The tracing macros are only meant for interpret(); undefine all of them so
// none leak into the rest of the translation unit.
#undef DEBUG_OPCODE
#undef DEBUG_OPCODE_26
#undef DEBUG_OPCODE_TY_20

}  // namespace gapir
